
#ifndef __inc_cassandra
#define __inc_cassandra

#include "../../PGBasics.hh"
#include "../../Sampler.hh"
/*
 *  Cassandra.hh
 *  
 *
 *  Created by Owen Thomas on 16/06/06.
 *  Copyright 2006 __MyCompanyName__. All rights reserved.
 *
 */
#include <ostream>
#include <map>

// NOTE(review): a using-directive at namespace scope in a header makes every
// libpg name visible in each file that includes this header. The unqualified
// Simulator, SparseMatrix, Vector and Observation names used below presumably
// resolve through it -- confirm before narrowing or removing it.
using namespace libpg;

/**
 * Simulator over a finite set of states, actions and observations.
 *
 * The model is supplied to the constructor as sparse matrices: per-action
 * transition probabilities, per-action observation probabilities and a
 * reward structure. doAction() performs a step; the observation and reward
 * accessors then report what that step produced.
 *
 * Every member, data included, is public. The pointers are stored as given;
 * the destructor in this header releases nothing, so ownership of the
 * matrices is not established here -- TODO confirm who frees them.
 */
class Cassandra : public Simulator
{
public:

	// ------------------------------------------------------------------
	// Model description. Intended to be assigned at construction only.
	// ------------------------------------------------------------------

	// Start-state -> end-state transition probabilities; one square matrix
	// per action. (Earlier layout: map<int,SparseMatrix*>.)
	SparseMatrix* actionTransitions;

	// End-state -> observation probabilities; one rectangular matrix per
	// action. (Earlier layout: map<int,SparseMatrix*>.)
	SparseMatrix* actionObservations;

	// Reward matrices. The earlier layout was
	// map<int, map<int, SparseMatrix*> >, described as
	// "action -> end-state -> end-state * observation".
	// NOTE(review): the exact indexing of this SparseMatrix** is not visible
	// in this header -- confirm against the implementation.
	SparseMatrix** rewards;

	unsigned int numActions;       // Number of actions.
	unsigned int numStates;        // Number of states.
	unsigned int numObservations;  // Number of observations.

	float discount;                // Value reported by getDiscount().

	// ------------------------------------------------------------------
	// Simulation state. Clients are meant to change it only indirectly,
	// through doAction.
	// ------------------------------------------------------------------

	int currState;        // Current state.
	int currObservation;  // Current observation (see getObservation()).
	int currAction;       // Current action.
	int lastState;        // Presumably the state before the latest step -- TODO confirm.

	// "Cheating" switch, directly mutable. While set, getNumObs() reports
	// numStates instead of numObservations; any further effect lives in the
	// out-of-line implementation.
	bool cheat;

	/**
	 * Initialisation routine taking the full parameter set, including the
	 * start state. Defined out of line; presumably shared by both
	 * constructors -- TODO confirm.
	 */
	void init (SparseMatrix* actionTransitions,
	           SparseMatrix* actionObservations,
	           SparseMatrix** rewards,
	           int numActions, int numStates,
	           int numObservations, int startState, float discount);

	/**
	 * Construct a Cassandra simulator from the given action and
	 * observation probability matrices.
	 *
	 * @param actionTransitions an array, of length numActions, of
	 * transition probabilities from a start state (row index) to an end
	 * state (column index). Each row must sum to 1.0.
	 *
	 * @param actionObservations an array, of length numActions, of
	 * observation probabilities indexed by state (row index). Each row must
	 * sum to 1.0.
	 * NOTE(review): the previous text gave the column index as "an end
	 * state", while the member comment describes these matrices as
	 * end-state -> observation; the column is presumably the observation
	 * index -- confirm.
	 *
	 * @param rewards the reward matrices (see the member of the same name).
	 * @param numActions the number of actions.
	 * @param numStates the number of states.
	 * @param numObservations the number of observations.
	 * @param discount the discount, 0.95 unless specified.
	 *
	 * This overload takes no start state; how the initial state is chosen
	 * is decided by the out-of-line definition.
	 *
	 * NOTE: with seven arguments, a final argument of type double (e.g. a
	 * bare 0.9) converts equally well to this overload's float discount and
	 * to the other overload's int startState, so the call is ambiguous.
	 * Pass a float (0.9f) to select this overload.
	 */
	Cassandra (SparseMatrix* actionTransitions,
	           SparseMatrix* actionObservations,
	           SparseMatrix** rewards,
	           int numActions, int numStates,
	           int numObservations,
	           float discount = 0.95);

	/**
	 * As the constructor above, with an explicit start state.
	 *
	 * @param startState the state passed as the starting state.
	 */
	Cassandra (SparseMatrix* actionTransitions,
	           SparseMatrix* actionObservations,
	           SparseMatrix** rewards,
	           int numActions, int numStates,
	           int numObservations, int startState,
	           float discount = 0.95);

	// Releases nothing; whether the matrices should be deleted here is an
	// open question carried over from the original author.
	virtual ~Cassandra () {}

	/*
	 * TODO: constructor taking a single, global observation matrix:
	 *
	 *   Cassandra (SparseMatrix* actionTransitions,
	 *              SparseMatrix observations,
	 *              int numActions, int numStates, int numObservations);
	 */

	// ------------------------------------------------------------------
	// Stepping and results.
	// ------------------------------------------------------------------

	/**
	 * Perform the specified action, generating a state and an observation.
	 *
	 * @param action the action to perform. It is assumed to hold an
	 * integer and will be cast as such.
	 * @return an int whose meaning is set by the implementation (not
	 * visible in this header).
	 */
	virtual int doAction (Vector& action);

	/**
	 * Perform the action with the specified index, generating a state and
	 * an observation.
	 *
	 * @param action the index of the action to perform.
	 * @return an int whose meaning is set by the implementation (not
	 * visible in this header).
	 */
	virtual int doAction (int action);

	/**
	 * Writes into obs the observation generated by the previous call to
	 * doAction.
	 *
	 * Observations are produced when the action is performed, not when
	 * they are fetched: repeated calls without an intervening doAction
	 * yield the same observation.
	 */
	virtual void getObservation (Observation& obs);

	/**
	 * Writes the last generated reward into rewards.
	 */
	virtual void getReward (Vector& rewards);

	// Integer form of the reward query; presumably the same last generated
	// reward as above -- TODO confirm in the implementation.
	virtual int getReward ();

	// ------------------------------------------------------------------
	// Read-only queries.
	// ------------------------------------------------------------------

	virtual float getDiscount () { return discount; }

	virtual int getNumActions () { return numActions; }

	// The current observation as a plain integer index.
	virtual int getObservation () { return currObservation; }

	// Size of the observation space: numStates while cheat is set,
	// numObservations otherwise.
	virtual int getNumObs () {
		return cheat ? numStates : numObservations;
	}

	// Unlike its neighbours this accessor is not declared virtual.
	int getNumStates () { return numStates; }

	// Cheat-mode setter, defined out of line; presumably stores its
	// argument in the cheat member -- TODO confirm.
	virtual void setCheat (bool cheat = true);

	// ------------------------------------------------------------------
	// Fixed dimensions.
	// ------------------------------------------------------------------

	// An observation is just a single integer, hence 1 x 1.
	virtual int getObsRows () { return 1; }

	virtual int getObsCols () { return 1; }

	// One agent.
	virtual int getAgents () { return 1; }

	virtual int getActionDim () { return 1; }

	virtual int getRewardDim () { return 1; }

	// Defined out of line; output format and destination are not visible
	// in this header.
	virtual void print ();
};

#endif
